package cn.lyh.mySpring;

import cn.lyh.mySpring.Handler.ResponseBodyHandler;
import cn.lyh.mySpring.annotation.*;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import org.apache.log4j.Logger;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.net.URL;
import java.util.*;

/**
 * Minimal front-controller servlet for the mySpring framework.
 *
 * <p>During {@link #init(ServletConfig)} it loads the properties file named by the
 * {@code contextConfigLocation} init parameter, scans the package named by the
 * {@code scan.package} property, instantiates {@code @MyController}, {@code @MyService}
 * and {@code @MyResponseAdvice} classes, injects {@code @MyAutowired} fields and builds
 * the URL-to-method mapping from {@code @MyRequestMapping}. GET and POST requests are
 * then dispatched to the mapped controller method.
 *
 * @author lyh
 */
public class MyDispatcherServlet extends HttpServlet {
    /** Properties loaded from the {@code contextConfigLocation} file. */
    private final Properties contextConfig = new Properties();
    /** Fully qualified names of the classes found by the package scan. */
    private final List<String> classNames = new ArrayList<>();
    /** IoC container: bean name (or implemented interface name) to instance. */
    private final Map<String, Object> ioc = new HashMap<>();
    /** Request URL (relative to the context path) to handler method. */
    private final Map<String, Method> handlerMapping = new HashMap<>();
    private static final Logger logger = Logger.getLogger(MyDispatcherServlet.class);
    /** Optional post-processor for {@code @MyResponseBody} results, from a {@code @MyResponseAdvice} class. */
    private ResponseBodyHandler responseBodyHandler;

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doDispatcherServlet(req, resp);
    }


    /**
     * Boots the framework: configuration, scanning, instantiation, injection and URL mapping.
     *
     * @param config servlet configuration; expected to carry the {@code contextConfigLocation} init parameter
     * @throws ServletException if any initialization step fails (the original exception is attached as cause)
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // GenericServlet.init(ServletConfig) stores the config; skipping it breaks
        // getServletConfig()/getServletContext() for this servlet.
        super.init(config);
        String contextConfigLocation = config.getInitParameter("contextConfigLocation");
        try {
            initMyDispatcherServlet(contextConfigLocation);
        } catch (Exception e) {
            logger.error("mySpring init fail", e);
            throw new ServletException(e.getMessage(), e);
        }
    }


    /**
     * Maps the request URL to its handler method and invokes it.
     *
     * @param request  current request
     * @param response current response
     */
    private void doDispatcherServlet(HttpServletRequest request, HttpServletResponse response) {
        invoke(request, response);
    }


    /**
     * Looks up the handler for the request, binds its arguments, invokes it and writes the result.
     * Unknown URLs get a 404 body; unconvertible numeric parameters get a 400; failures while
     * invoking the handler or writing the result are logged and answered with a 500.
     */
    private void invoke(HttpServletRequest request, HttpServletResponse response) {
        Method method = handlerMapping.get(resolveHandlerUrl(request));
        if (null == method) {
            writeNotFound(request, response);
            return;
        }
        Object[] paramValues;
        try {
            paramValues = getMethodParamAndValue(request, response, method);
        } catch (NumberFormatException e) {
            // a request parameter could not be converted to the declared numeric type
            logger.debug("request fail(400): " + request.getRequestURI() + " -> " + e.getMessage());
            response.setStatus(400);
            return;
        }
        try {
            // controllers are registered under their lower-cased simple class name (see registerController)
            String controllerClassName = toFirstWordLower(method.getDeclaringClass().getSimpleName());
            Object object = method.invoke(ioc.get(controllerClassName), paramValues);
            if (object != null) {
                if (method.isAnnotationPresent(MyResponseBody.class)) {
                    response.setHeader("content-type", "application/json;charset=UTF-8");
                    if (null == responseBodyHandler) {
                        object = JSONObject.toJSONString(object, SerializerFeature.WriteMapNullValue);
                    } else {
                        // NOTE(review): this calls Object.equals, so the body becomes "true"/"false"
                        // instead of the handler's output. The intended ResponseBodyHandler method is not
                        // visible from this file -- TODO replace with the real handler call.
                        object = responseBodyHandler.equals(object);
                    }
                }
                response.getWriter().print(object);
                logger.debug("request-> " + request.getRequestURI() + ", response success ->" + response.getStatus());
            }
        } catch (InvocationTargetException e) {
            // the handler method itself threw; log the real cause, not the reflection wrapper
            logger.error("request-> " + request.getRequestURI() + ", handler threw exception", e.getCause());
            response.setStatus(500);
        } catch (IllegalAccessException | IllegalArgumentException | IOException e) {
            logger.error("request-> " + request.getRequestURI() + ", dispatch fail", e);
            response.setStatus(500);
        }
    }

    /**
     * Computes the handler-mapping key for a request: the request URI without the
     * context path, with runs of '/' collapsed to a single '/'.
     */
    private String resolveHandlerUrl(HttpServletRequest request) {
        String uri = request.getRequestURI();
        String contextPath = request.getContextPath();
        // mappings are registered relative to the web application, so strip e.g. "/app"
        if (contextPath != null && !contextPath.isEmpty() && uri.startsWith(contextPath)) {
            uri = uri.substring(contextPath.length());
        }
        return ("/" + uri).replaceAll("/+", "/");
    }

    /** Writes the 404 response body for a URL with no mapped handler. */
    private void writeNotFound(HttpServletRequest request, HttpServletResponse response) {
        response.setStatus(404);
        logger.debug("request fail(404)： " + request.getRequestURI());
        try (PrintWriter pw = response.getWriter()) {
            pw.print("404    not find       ->        " + request.getRequestURI());
            pw.flush();
        } catch (IOException e) {
            logger.error("write 404 response fail: " + request.getRequestURI(), e);
        }
    }

    /**
     * Binds request data to the handler method's parameters.
     *
     * <p>{@code ServletRequest}/{@code ServletResponse} parameters receive the current
     * request/response. Every other parameter is read with {@code request.getParameter},
     * using the {@code @MyRequestParam} value as the name when present, otherwise the
     * reflective parameter name (which is only the source name when compiled with
     * {@code -parameters}). Values are converted for int/Integer, float/Float,
     * double/Double, long/Long and boolean/Boolean; any other type receives the raw
     * String. Binding to entity objects is not supported.
     *
     * @param request  current request
     * @param response current response
     * @param method   handler method
     * @return argument array for {@code method.invoke}; a missing parameter yields {@code null}
     * @throws NumberFormatException if a numeric parameter cannot be parsed
     */
    private Object[] getMethodParamAndValue(HttpServletRequest request, HttpServletResponse response, Method method) {
        Parameter[] parameters = method.getParameters();
        Object[] paramValues = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Class<?> type = parameters[i].getType();
            if (ServletRequest.class.isAssignableFrom(type)) {
                paramValues[i] = request;
            } else if (ServletResponse.class.isAssignableFrom(type)) {
                paramValues[i] = response;
            } else {
                String bindingValue = parameters[i].getName();
                if (parameters[i].isAnnotationPresent(MyRequestParam.class)) {
                    bindingValue = parameters[i].getAnnotation(MyRequestParam.class).value();
                }
                String paramValue = request.getParameter(bindingValue);
                paramValues[i] = paramValue == null ? null : convertParamValue(paramValue, type);
            }
        }
        return paramValues;
    }

    /**
     * Converts a raw request parameter to the given parameter type. Primitive and wrapper
     * types are both accepted; {@code Method.invoke} unboxes wrappers for primitive parameters.
     *
     * @return the converted value, or {@code raw} itself for unsupported types
     */
    private Object convertParamValue(String raw, Class<?> type) {
        if (type == Integer.class || type == int.class) {
            return Integer.parseInt(raw);
        } else if (type == Float.class || type == float.class) {
            return Float.parseFloat(raw);
        } else if (type == Double.class || type == double.class) {
            return Double.parseDouble(raw);
        } else if (type == Long.class || type == long.class) {
            return Long.parseLong(raw);
        } else if (type == Boolean.class || type == boolean.class) {
            return Boolean.parseBoolean(raw);
        }
        return raw;
    }


    /**
     * Runs all initialization steps in order.
     *
     * @param contextConfigLocation classpath location of the properties file
     * @throws Exception if any step fails
     */
    private void initMyDispatcherServlet(String contextConfigLocation) throws Exception {
        logger.info("-----------------------------mySpring init start-----------------------------------------");
        logger.debug("doLoadConfig:" + contextConfigLocation);
        // load configuration
        doLoadConfig(contextConfigLocation);
        // package scan
        logger.debug("scan:" + contextConfig.getProperty("scan.package"));
        doScanner(contextConfig.getProperty("scan.package"));
        // instantiate beans into the IoC container
        doInstance();
        // dependency injection
        doAutowired();
        // URL mapping
        initHandlerMapping();

    }

    /**
     * Injects {@code @MyAutowired} fields of every bean. A non-blank annotation value is
     * used as the bean name; otherwise the field type is resolved by fully qualified name
     * (interface registrations) and then by lower-cased simple name (class registrations).
     * Fields with no matching bean are left untouched and a warning is logged.
     *
     * @throws IllegalStateException if a field cannot be written
     */
    private void doAutowired() {
        for (Object bean : ioc.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                MyAutowired myAutowired = field.getAnnotation(MyAutowired.class);
                if (myAutowired == null) {
                    continue;
                }
                String alias = myAutowired.value().trim();
                Object dependency = alias.isEmpty() ? findBeanByType(field.getType()) : ioc.get(alias);
                if (dependency == null) {
                    logger.warn("autowired fail, no bean for field: " + bean.getClass().getName() + "." + field.getName());
                    continue;
                }
                field.setAccessible(true);
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException("autowired fail: " + bean.getClass().getName() + "." + field.getName(), e);
                }
            }
        }
    }

    /** Finds a bean for a field type: first by fully qualified name, then by lower-cased simple name. */
    private Object findBeanByType(Class<?> type) {
        Object bean = ioc.get(type.getName());
        return bean != null ? bean : ioc.get(toFirstWordLower(type.getSimpleName()));
    }

    /**
     * Builds the URL-to-method mapping from {@code @MyRequestMapping} on controller classes
     * (optional prefix) and their declared methods.
     *
     * @throws IllegalStateException if two handler methods map to the same URL
     */
    private void initHandlerMapping() {
        for (Object bean : ioc.values()) {
            Class<?> clazz = bean.getClass();
            if (!clazz.isAnnotationPresent(MyController.class)) {
                continue;
            }
            // the class-level mapping is optional; without it, method mappings are used as-is
            MyRequestMapping requestMapping = clazz.getAnnotation(MyRequestMapping.class);
            String crlRequstMapping = (requestMapping == null || requestMapping.value() == null) ? "" : requestMapping.value();
            for (Method method : clazz.getDeclaredMethods()) {
                MyRequestMapping methodMapping = method.getAnnotation(MyRequestMapping.class);
                if (methodMapping == null) {
                    continue;
                }
                String url = ("/" + crlRequstMapping + "/" + methodMapping.value()).replaceAll("/+", "/");
                // request urls must be unique
                if (handlerMapping.containsKey(url)) {
                    logger.error("mapping request url:" + url + "is already exist! request url must only");
                    throw new IllegalStateException("mapping:" + url + "is already exist!");
                }
                handlerMapping.put(url, method);
                logger.debug("mapping: " + url);
            }
        }
    }

    /**
     * Loads the properties file from the classpath into {@link #contextConfig}.
     *
     * @param contextConfigLocation classpath resource name of the properties file
     * @throws Exception if the resource does not exist or cannot be read
     */
    private void doLoadConfig(String contextConfigLocation) throws Exception {
        InputStream in = contextConfigLocation == null
                ? null
                : this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation);
        if (in == null) {
            logger.error("config:" + contextConfigLocation + " not exist");
            throw new Exception("config:" + contextConfigLocation + " not exist");
        }
        // try-with-resources closes the stream; read errors now abort init instead of being swallowed
        try (InputStream is = in) {
            contextConfig.load(is);
        }
    }

    /**
     * Recursively collects the names of {@code .class} files under a package directory
     * into {@link #classNames}. Only packages on the file system (exploded directories)
     * are supported.
     *
     * @param packageName dot-separated package name
     * @throws Exception if the package name is empty or the package is not a directory
     */
    private void doScanner(String packageName) throws Exception {
        if (packageName == null || packageName.trim().length() == 0) {
            throw new Exception("init scan is empty");
        }
        String pkg = packageName.trim();
        // ClassLoader resource names are relative; a leading '/' is not valid for standard loaders
        URL url = this.getClass().getClassLoader().getResource(pkg.replace('.', '/'));
        if (null == url) {
            logger.warn("scan package not found: " + pkg);
            return;
        }
        if (!"file".equals(url.getProtocol())) {
            throw new IllegalStateException("scan package must be a file system directory: " + url);
        }
        // toURI() decodes escapes such as %20 that url.getFile() would keep
        File[] files = new File(url.toURI()).listFiles();
        if (files == null) {
            return;
        }
        String classSuffix = ".class";
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                // recurse into sub-packages
                doScanner(pkg + "." + fileName);
            } else if (fileName.endsWith(classSuffix)) {
                String className = pkg + "." + fileName.substring(0, fileName.length() - classSuffix.length());
                logger.debug("scan class find:" + className);
                classNames.add(className);
            }
        }

    }

    /**
     * Instantiates every scanned class annotated with {@code @MyController},
     * {@code @MyService} or {@code @MyResponseAdvice}.
     *
     * @throws Exception if a class cannot be loaded or instantiated, a bean name is
     *                   duplicated, or a {@code @MyResponseAdvice} class does not implement
     *                   {@code ResponseBodyHandler}
     */
    private void doInstance() throws Exception {
        for (String className : classNames) {
            Class<?> clazz = Class.forName(className);
            if (clazz.isAnnotationPresent(MyController.class)) {
                registerController(clazz);
            } else if (clazz.isAnnotationPresent(MyService.class)) {
                registerService(clazz);
            } else if (clazz.isAnnotationPresent(MyResponseAdvice.class)) {
                registerResponseAdvice(clazz);
            }
        }
    }

    /** Registers a controller under its lower-cased simple class name. */
    private void registerController(Class<?> clazz) throws Exception {
        String key = toFirstWordLower(clazz.getSimpleName());
        // invoke() resolves controllers by this key, so a duplicate would misroute requests
        if (ioc.containsKey(key)) {
            logger.error("MyController instance: " + key + "  is  exist");
            throw new IllegalStateException("MyController instance: " + key + "  is  exist");
        }
        logger.debug("MyController instance: " + clazz.getName());
        ioc.put(key, createInstance(clazz));
    }

    /**
     * Registers a service once, under its {@code @MyService} value (or lower-cased simple
     * class name) and under the fully qualified name of each directly implemented interface.
     */
    private void registerService(Class<?> clazz) throws Exception {
        Object newInstance = createInstance(clazz);
        String key = toFirstWordLower(clazz.getSimpleName());
        logger.debug("MyService instance: " + clazz.getName());
        String value = clazz.getAnnotation(MyService.class).value().trim();
        if (!"".equals(value)) {
            key = value;
        }
        if (ioc.containsKey(key)) {
            logger.error("MyService instance: " + key + "  is  exist");
            throw new IllegalStateException("MyService instance: " + key + "  is  exist");
        }
        ioc.put(key, newInstance);
        // the same instance is shared by all interface keys (singleton)
        for (Class<?> interClazz : clazz.getInterfaces()) {
            ioc.put(interClazz.getName(), newInstance);
        }
    }

    /** Installs the first {@code @MyResponseAdvice} class as the response-body handler. */
    private void registerResponseAdvice(Class<?> clazz) throws Exception {
        if (!ResponseBodyHandler.class.isAssignableFrom(clazz)) {
            logger.error("class+'" + clazz.getName() + "' must implement ResponseBodyHandler");
            throw new IllegalStateException("class+'" + clazz.getName() + "' must implement ResponseBodyHandler");
        }
        if (null != responseBodyHandler) {
            logger.warn("MyResponseAdvice ignored, handler already registered: " + clazz.getName());
            return;
        }
        responseBodyHandler = (ResponseBodyHandler) createInstance(clazz);
    }

    /** Creates an instance through the no-arg constructor. */
    private Object createInstance(Class<?> clazz) throws Exception {
        return clazz.getDeclaredConstructor().newInstance();
    }

    /**
     * Lower-cases the first character of a string.
     *
     * @param name input string; may be {@code null} or empty
     * @return the string with its first character lower-cased, or the input unchanged if null/empty
     */
    private String toFirstWordLower(String name) {
        if (name == null || name.isEmpty()) {
            return name;
        }
        char[] charArray = name.toCharArray();
        // Character.toLowerCase leaves non-uppercase characters intact (unlike adding 32)
        charArray[0] = Character.toLowerCase(charArray[0]);
        return String.valueOf(charArray);
    }

}
